import os
import time
import numpy as np
import pickle as pkl
import torch
import sys
from tqdm import tqdm 
from torch import optim
from transformers import BertTokenizer
from utils import *
from data import *
from model import BubbleEmbed
import json


class Experiments(object):
    """Training / evaluation driver for a ``BubbleEmbed`` model.

    On construction it loads a BERT tokenizer, the train and test splits (via
    ``load_data``), builds the model, moves it to GPU when ``args.cuda`` is set,
    and creates two optimizers: one for parameters whose names start with
    ``pre_train`` and one for parameters whose names start with ``projection``.

    Artifacts are written under ``../result/<dataset>/``:
    ``model/`` holds the best checkpoint, ``train/`` holds the most recent
    per-epoch checkpoint, and ``res_<exp_setting>.json`` collects test results.
    """

    # Hyper-parameters (in this order) that form the experiment identifier used
    # in checkpoint and result file names.
    _EXP_SETTING_KEYS = (
        "pre_train",
        "dataset",
        "expID",
        "epochs",
        "embed_size",
        "batch_size",
        "lr",
        "margin",
        "epsilon",
        "phi",
        "negsamples",
        "alpha",
        "beta",
        "gamma",
        "extra",
    )

    # Names of the six values returned by ``train_one_step`` (same order).
    _LOSS_NAMES = ("loss", "contain", "negative", "regular", "pos_prob", "neg_prob")

    def __init__(self, args):
        super(Experiments, self).__init__()

        self.args = args
        self.tokenizer = self.__load_tokenizer__()
        torch.cuda.empty_cache()
        self.train_loader, self.train_set = load_data(self.args, self.tokenizer, "train")
        self.test_loader, self.test_set = load_data(self.args, self.tokenizer, "test")
        self.model = BubbleEmbed(args, self.tokenizer)
        # Place the model on its device *before* building the optimizers so they
        # reference the final parameter tensors (ordering recommended by PyTorch).
        self._set_device()
        self.optimizer_pretrain, self.optimizer_projection = self._select_optimizer()
        self.exp_setting = self._build_exp_setting()
        self.setting = {
            "pre_train": self.args.pre_train,
            "dataset": self.args.dataset,
            "expID": self.args.expID,
            "epochs": self.args.epochs,
            "embed_size": self.args.embed_size,
            "batch_size": self.args.batch_size,
            "lr": self.args.lr,
            "margin": self.args.margin,
            "epsilon": self.args.epsilon,
            "size": self.args.phi,
            "alpha": self.args.alpha,
            "beta": self.args.beta,
            "gamma": self.args.gamma,
            "negsamples": self.args.negsamples,
            "minvol": self.args.minvol,
            "contrastive": self.args.contrastive,
            "radratio": self.args.radratio,
            "seed": self.args.seed,
            "theta": self.args.theta,
            "delta": self.args.delta,
        }

    def __load_tokenizer__(self):
        """Load and return the ``bert-base-uncased`` tokenizer."""
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        print("Tokenizer Loaded!")
        return tokenizer

    def _build_exp_setting(self):
        """Return the ``_``-joined string of the ``_EXP_SETTING_KEYS`` values in ``args``."""
        return "_".join(str(getattr(self.args, key)) for key in self._EXP_SETTING_KEYS)

    def _select_optimizer(self):
        """Build optimizers for the ``pre_train`` and ``projection`` parameters.

        Returns:
            ``(optimizer_pretrain, optimizer_projection)``. Both use zero weight
            decay. The first uses ``args.lr`` and the second uses
            ``args.lr_projection``. ``args.eps`` is used only for ``"adamw"``.

        Raises:
            ValueError: if ``args.optim`` is neither ``"adam"`` nor ``"adamw"``.
        """

        def _group(prefix):
            # A single param group holding every parameter whose name has `prefix`.
            return [
                {
                    "params": [
                        p for n, p in self.model.named_parameters() if n.startswith(prefix)
                    ],
                    "weight_decay": 0.0,
                }
            ]

        pre_train_parameters = _group("pre_train")
        projection_parameters = _group("projection")

        if self.args.optim == "adam":
            return (
                optim.Adam(pre_train_parameters, lr=self.args.lr),
                optim.Adam(projection_parameters, lr=self.args.lr_projection),
            )
        if self.args.optim == "adamw":
            return (
                optim.AdamW(pre_train_parameters, lr=self.args.lr, eps=self.args.eps),
                optim.AdamW(
                    projection_parameters, lr=self.args.lr_projection, eps=self.args.eps
                ),
            )
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unsupported optimizer {self.args.optim!r}; expected 'adam' or 'adamw'"
        )

    def _set_device(self):
        """Move the model to GPU when ``args.cuda`` is truthy."""
        if self.args.cuda:
            self.model = self.model.cuda()

    def _load_weights(self, path):
        """Load a saved ``state_dict`` from ``path`` into ``self.model``.

        The file is mapped to CPU first, so a checkpoint saved on GPU can also be
        loaded when CUDA is off. ``load_state_dict`` then copies the values into
        the model's parameters on whatever device they live on.
        """
        self.model.load_state_dict(torch.load(path, map_location="cpu"))

    def train_one_step(self, it, encode_parent, encode_child, encode_negative_parents):
        """Run one optimisation step on a single batch.

        Args:
            it: batch index (not used by the step itself).
            encode_parent, encode_child, encode_negative_parents: batch inputs,
                forwarded unchanged to ``self.model``.

        Returns:
            The six values returned by the model's forward pass:
            ``(loss, loss_contain, loss_negative, regular_loss, loss_pos_prob,
            loss_neg_prob)``. Only ``loss`` is back-propagated.
        """
        self.model.train()
        self.optimizer_pretrain.zero_grad()
        self.optimizer_projection.zero_grad()

        loss, loss_contain, loss_negative, regular_loss, loss_pos_prob, loss_neg_prob = self.model(
            encode_parent,
            encode_child,
            encode_negative_parents,
        )
        loss.backward()
        self.optimizer_pretrain.step()
        self.optimizer_projection.step()

        return loss, loss_contain, loss_negative, regular_loss, loss_pos_prob, loss_neg_prob

    @staticmethod
    def _epoch_averages(history):
        """Return ``{name: mean(values)}`` for each list of per-batch loss values."""
        return {name: np.average(values) for name, values in history.items()}

    def train(self, checkpoint=None, save_path=None, limit=0.6):
        """Train for ``args.epochs`` epochs, evaluating on the test split after each.

        Args:
            checkpoint: optional path of a ``state_dict`` to start from.
            save_path: where to save the best model. Defaults to
                ``../result/<dataset>/model/exp_model_<exp_setting>.checkpoint``.
            limit: stop early once test accuracy reaches this value.

        Besides the best model, the most recent epoch's weights are kept in
        ``../result/<dataset>/train/``; the previous epoch's file is deleted.
        """
        time_tracker = []
        test_acc = test_mrr = test_wu_p = 0
        old_test_acc = old_test_mrr = old_test_wu_p = 0

        if checkpoint:
            self._load_weights(checkpoint)

        # Rolling per-epoch checkpoints always go here, so the directory must
        # exist even when the caller supplies its own save_path.
        traindir = os.path.join("../result", self.args.dataset, "train")
        os.makedirs(traindir, exist_ok=True)
        if save_path is None:
            savedir = os.path.join("../result", self.args.dataset, "model")
            os.makedirs(savedir, exist_ok=True)
            save_path = os.path.join(savedir, f"exp_model_{self.exp_setting}.checkpoint")

        def epoch_checkpoint(ep):
            return os.path.join(traindir, f"exp_model_{self.exp_setting}_{ep}.checkpoint")

        for epoch in tqdm(range(self.args.epochs)):
            epoch_time = time.time()

            # Keyed by loss name so each component is averaged from its own list.
            history = {name: [] for name in self._LOSS_NAMES}
            for i, (encode_parent, encode_child, encode_negative_parents) in tqdm(
                enumerate(self.train_loader), total=len(self.train_loader)
            ):
                step_losses = self.train_one_step(
                    it=i,
                    encode_parent=encode_parent,
                    encode_child=encode_child,
                    encode_negative_parents=encode_negative_parents,
                )
                for name, value in zip(self._LOSS_NAMES, step_losses):
                    history[name].append(value.item())

            # Only the total loss is logged below; the other averages are
            # available in `epoch_losses` for inspection.
            epoch_losses = self._epoch_averages(history)
            train_loss = epoch_losses["loss"]

            test_metrics = self.predict()
            test_acc = test_metrics["Acc"]
            test_mrr = test_metrics["MRR"]
            test_wu_p = test_metrics["Wu"]

            # Save the best model. Higher accuracy wins. On an accuracy tie, the
            # new model is saved if either MRR or Wu did not decrease.
            if test_acc > old_test_acc or (
                test_acc == old_test_acc
                and (old_test_mrr <= test_mrr or old_test_wu_p <= test_wu_p)
            ):
                torch.save(self.model.state_dict(), save_path)
                old_test_acc, old_test_mrr, old_test_wu_p = test_acc, test_mrr, test_wu_p

            elapsed = time.time() - epoch_time
            time_tracker.append(elapsed)

            print(
                "Epoch: {:04d}".format(epoch + 1),
                " train_loss:{:.05f}".format(train_loss),
                "acc:{:.05f}".format(test_acc),
                "mrr:{:.05f}".format(test_mrr),
                "wu_p:{:.05f}".format(test_wu_p),
                " epoch_time:{:.01f}s".format(elapsed),
                " remain_time:{:.01f}s".format(
                    np.mean(time_tracker) * (self.args.epochs - (1 + epoch))
                ),
            )

            # Keep only the latest epoch's weights on disk.
            torch.save(self.model.state_dict(), epoch_checkpoint(epoch))
            if epoch:
                previous = epoch_checkpoint(epoch - 1)
                if os.path.exists(previous):
                    os.remove(previous)
            if test_acc >= limit:
                break

    @staticmethod
    def _rank_candidates(pred_scores, pred_volumes, pred_distances, theta):
        """Rank candidates for every query (row) by a blended score.

        Center distances are min-max normalised per row and inverted, so the
        closest candidate gets ~1.0 and the furthest ~0.0. The ranking score is
        ``(1 - theta) * pred_scores + theta * normalized_distances``. Candidates
        are sorted by descending score; ties are broken by ascending volume.

        Returns:
            ``(ind, sorted_scores)``: per-row candidate indices in ranked order,
            and the blended scores in that same order.
        """
        min_distances = np.min(pred_distances, axis=1)[:, np.newaxis]
        max_distances = np.max(pred_distances, axis=1)[:, np.newaxis]
        # The 1e-8 guards rows where every distance is equal.
        normalized_distances = 1.0 - (
            (pred_distances - min_distances) / (max_distances - min_distances + 1e-8)
        )
        combined_scores = (1 - theta) * pred_scores + theta * normalized_distances
        # np.lexsort treats the LAST key as primary: descending score, then volume.
        ind = np.lexsort((pred_volumes, -combined_scores))
        return ind, np.take_along_axis(combined_scores, ind, axis=1)

    def _write_results(self, test_metrics):
        """Append the run arguments and ``test_metrics`` to the results JSON file.

        The file is ``../result/<dataset>/res_<exp_setting>.json``, in the same
        root as the checkpoints written by ``train``. Repeated runs with the same
        setting append further JSON documents to the same file.
        """
        result_dir = os.path.join("../result", self.args.dataset)
        os.makedirs(result_dir, exist_ok=True)
        result_path = os.path.join(result_dir, f"res_{self.exp_setting}.json")
        with open(result_path, "a", encoding="utf-8") as f:
            expt_details = {
                "Arguments": vars(self.args),
                "Test Metrics": test_metrics,
            }
            json.dump(expt_details, f, indent=4)
            f.write("\n")

    def predict(self, tag=None, path=None):
        """Score every test query against every test candidate and compute metrics.

        Args:
            tag: when ``"test"``, first load weights (from ``path`` or the default
                best-model checkpoint), then print the metrics and append them
                to the results file.
            path: optional checkpoint path, used only when ``tag == "test"``.

        Returns:
            The metrics dict produced by ``metrics`` (read here via the keys
            ``Acc``, ``MRR``, ``Wu``, ``MR``, ``Prec@5``, ``Prec@10``, ``NDCG``).
        """
        print("Prediction starting.....")
        if tag == "test":
            model_path = path if path else f"../result/{self.args.dataset}/model/exp_model_{self.exp_setting}.checkpoint"
            self._load_weights(model_path)

        self.model.eval()
        score_list, volume_list, contain_list, distance_list = [], [], [], []

        with torch.no_grad():
            encode_query = self.test_set.encode_query
            if self.args.cuda:
                encode_query = {
                    key: encode_query[key].cuda()
                    for key in ("input_ids", "token_type_ids", "attention_mask")
                }
            query_center, query_delta = self.model.projection_bubble(encode_query)

            # NOTE(review): candidate batches are passed as yielded by test_loader
            # and are not moved with .cuda() here (unlike the query); presumably
            # the loader or model handles placement — confirm.
            candidate_centers, candidate_deltas = [], []
            for encode_candidate in self.test_loader:
                candidate_center, candidate_delta = self.model.projection_bubble(encode_candidate)
                candidate_centers.append(candidate_center)
                candidate_deltas.append(candidate_delta)

            candidate_center = torch.cat(candidate_centers, dim=0)
            candidate_delta = torch.cat(candidate_deltas, dim=0)
            num_query = query_center.shape[0]
            num_candidate = candidate_center.shape[0]

            for i in tqdm(range(num_query), desc="Validation Queries", total=num_query):
                # Broadcast query i against all candidates at once.
                extend_center = query_center[i].expand(num_candidate, -1)
                extend_delta = query_delta[i].expand(num_candidate, -1)

                score, volume = self.model.condition_score(extend_center, extend_delta, candidate_center, candidate_delta)
                is_contain = self.model.is_contain(extend_center, extend_delta, candidate_center, candidate_delta)
                center_distance = self.model.center_distance(extend_center, candidate_center)

                score_list.append(score.unsqueeze(dim=0))
                volume_list.append(volume.unsqueeze(dim=0))
                contain_list.append(is_contain.unsqueeze(dim=0))
                distance_list.append(center_distance.unsqueeze(dim=0))

            # One row per query, one column per candidate.
            pred_scores = torch.cat(score_list, dim=0).cpu().numpy()
            pred_volumes = torch.cat(volume_list, dim=0).cpu().numpy()
            pred_contain = torch.cat(contain_list, dim=0).cpu().numpy()
            pred_distances = torch.cat(distance_list, dim=0).cpu().numpy()

            ind, sorted_scores = self._rank_candidates(
                pred_scores, pred_volumes, pred_distances, self.args.theta
            )
            print(sorted_scores[:, :5])
            test_metrics = metrics(
                ind, self.test_set.test_gt_id, self.train_set.train_concept_set,
                self.test_set.path2root, self.test_set.id_concept, self.train_set.id_concept, self.test_set.test_concepts_id,
                sorted_scores
            )

        if tag == "test":
            print('acc:{:.05f}'.format(test_metrics["Acc"]),
                'mrr:{:.05f}'.format(test_metrics["MRR"]),
                'wu_p:{:.05f}'.format(test_metrics["Wu"]),
                'mr:{:.05f}'.format(test_metrics["MR"]),
                'prec5:{:.05f}'.format(test_metrics["Prec@5"]),
                'prec10:{:.05f}'.format(test_metrics["Prec@10"]),
                'NDCG:{:.05f}'.format(test_metrics["NDCG"]),
                )
            self._write_results(test_metrics)

        return test_metrics